import sys
import torch
from PIL import Image
import os.path as osp
import warnings
from .base import BaseModel
from ..smp import *
from huggingface_hub import snapshot_download


class OpenFlamingo(BaseModel):
    """Wrapper around OpenFlamingo-9B (CLIP ViT-L/14 vision encoder + MPT-7B LM).

    Requires the `open_flamingo` package, a local clone of MPT-7B, and the
    OpenFlamingo `checkpoint.pt` weights.
    """

    INSTALL_REQ = True  # open_flamingo must be installed manually by the user
    INTERLEAVE = True   # prompts may interleave images and text

    def __init__(self, name, mpt_pth=None, ckpt_pth=None, **kwargs):
        """Load the OpenFlamingo model and move it to GPU.

        Args:
            name (str): model tag; only "v2" is supported.
            mpt_pth (str): directory of the MPT-7B language model
                (clone of https://huggingface.co/mosaicml/mpt-7b).
            ckpt_pth (str): path to `checkpoint.pt`, a directory containing it,
                or a HuggingFace repo id ("org/name") to download.
            **kwargs: overrides merged into the default generation config.

        Raises:
            ValueError: if `mpt_pth`/`ckpt_pth` is missing or cannot be resolved.
        """
        if mpt_pth is None:
            raise ValueError(
                "Please set `mpt_pth` to the directory of MPT-7B, which is cloned from here: "
                "https://huggingface.co/mosaicml/mpt-7b. "
            )
        # NOTE: an unreachable second `raise ValueError` that followed the
        # raise above has been removed (dead code).
        if ckpt_pth is None:
            raise ValueError(
                "Please set `ckpt_pth` to the openflamingo ckpt, which is the `checkpoint.pt` file downloaded "
                "from: https://huggingface.co/openflamingo/OpenFlamingo-9B-vitl-mpt7b/tree/main. "
            )
        # Resolve `ckpt_pth` to a concrete checkpoint.pt file.
        if osp.exists(ckpt_pth):
            if ckpt_pth.endswith("checkpoint.pt"):
                pass  # already a direct path to the weights file
            elif osp.isdir(ckpt_pth):
                ckpt_pth = osp.join(ckpt_pth, "checkpoint.pt")
                if not osp.exists(ckpt_pth):
                    raise ValueError(f"File {ckpt_pth} does not exist. ")
        elif splitlen(ckpt_pth, "/") == 2:
            # Treat "org/name" as a HuggingFace repo id; download on cache miss.
            cache_path = get_cache_path(ckpt_pth)
            if cache_path is None:
                snapshot_download(ckpt_pth)
            cache_path = get_cache_path(ckpt_pth)
            if cache_path is None:
                # Bug fix: report the repo id — `cache_path` is always None here,
                # so the old message printed "Directory None does not exist".
                raise ValueError(f"Cache directory for {ckpt_pth} does not exist. ")
            else:
                ckpt_pth = osp.join(cache_path, "checkpoint.pt")

        self.name = name
        assert name in ["v2"]
        self.mpt_pth = mpt_pth
        try:
            from open_flamingo import create_model_and_transforms
        except Exception as e:
            logging.critical("Please first install open_flamingo to use OpenFlamingo")
            raise e

        model, image_processor, tokenizer = create_model_and_transforms(
            clip_vision_encoder_path="ViT-L-14",
            clip_vision_encoder_pretrained="openai",
            lang_encoder_path=mpt_pth,
            tokenizer_path=mpt_pth,
            cross_attn_every_n_layers=4,
        )
        # Load weights on CPU first so restoring tensors never spikes GPU
        # memory before the explicit `.cuda()` move below; strict=False
        # because the checkpoint covers only the Flamingo-specific layers.
        ckpt = torch.load(ckpt_pth, map_location="cpu")
        model.load_state_dict(ckpt, strict=False)
        torch.cuda.empty_cache()
        self.model = model.eval().cuda()
        self.tokenizer = tokenizer
        self.tokenizer.padding_side = "left"  # decoder-only LM: pad on the left
        self.image_proc = image_processor

        kwargs_default = dict(max_new_tokens=512, num_beams=3)
        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default
        warnings.warn(
            f"Following kwargs received: {self.kwargs}, will use as generation config. "
        )

    def generate_inner(self, message, dataset=None):
        """Generate a text answer for an interleaved image/text message.

        Args:
            message (list[dict]): items with keys "type" ("image"/"text")
                and "value" (image path / text fragment, respectively).
            dataset (str, optional): unused; kept for interface compatibility.

        Returns:
            str: generated text up to the first `<|endofchunk|>` marker.

        Raises:
            ValueError: if the message contains no image (the model requires
                at least one).
        """
        vision_x = []
        prompt = ""
        for msg in message:
            if msg["type"] == "image":
                img = Image.open(msg["value"])
                vision_x.append(self.image_proc(img).unsqueeze(0))
                prompt += "<image>"
            elif msg["type"] == "text":
                prompt += msg["value"]
        prompt += "Answer: "
        if not vision_x:
            # Bug fix: previously a text-only message crashed with a bare
            # IndexError on `vision_x[0]`; fail with a clear message instead.
            raise ValueError("OpenFlamingo requires at least one image input. ")
        # Stack per-image tensors, then add frame and batch axes —
        # presumably (batch=1, num_images, frames=1, C, H, W); confirm
        # against the open_flamingo `generate` contract.
        vision_x = torch.cat(vision_x, dim=0) if len(vision_x) > 1 else vision_x[0]
        vision_x = vision_x.unsqueeze(1).unsqueeze(0)
        lang_x = self.tokenizer([prompt], return_tensors="pt")
        generated_text = self.model.generate(
            vision_x=vision_x.cuda(),
            lang_x=lang_x["input_ids"].cuda(),
            attention_mask=lang_x["attention_mask"].cuda(),
            **self.kwargs,
        )
        generated_text = self.tokenizer.decode(generated_text[0])
        # Strip the echoed prompt, then truncate at the chunk delimiter.
        text = generated_text[len(prompt) :].split("<|endofchunk|>")[0]
        return text
